{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 解读 {class}`~tvm.target.generic_func.GenericFunc`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/media/pc/data/lxw/ai/tvm-book/doc/read\n"
     ]
    }
   ],
   "source": [
    "%cd ..\n",
    "import set_env"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tvm\n",
    "from tvm.target.generic_func import GenericFunc"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```cpp\n",
    "/*!\n",
    " * \\brief Generate the strategy of operators. This function is a generic\n",
    " * function and can be re-defined for different targets.\n",
    " *\n",
    " * The function signature of generic function is:\n",
    " *   OpStrategy(const Attrs& attrs, const Array<Tensor>& inputs,\n",
    " *              const Type& out_type, const Target& target)\n",
    " */\n",
    "using FTVMStrategy = GenericFunc;\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这段代码将 `FTVMStrategy` 定义为 `GenericFunc` 的类型别名，用于表示生成算子策略的通用函数。它的函数签名如下：\n",
    "\n",
    "```cpp\n",
    "OpStrategy(const Attrs& attrs, const Array<Tensor>& inputs,\n",
    "           const Type& out_type, const Target& target)\n",
    "```\n",
    "\n",
    "其中，`Attrs` 表示属性集合，`Array<Tensor>` 表示输入张量数组，`Type` 表示输出类型，`Target` 表示目标平台。这个函数可以根据不同的目标平台进行重定义。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[0;31mSignature:\u001b[0m \u001b[0mregister_strategy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mop_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfstrategy\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlevel\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mSource:\u001b[0m   \n",
      "\u001b[0;32mdef\u001b[0m \u001b[0mregister_strategy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mop_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfstrategy\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlevel\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0;34m\"\"\"Register strategy function for an op.\u001b[0m\n",
      "\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m    Parameters\u001b[0m\n",
      "\u001b[0;34m    ----------\u001b[0m\n",
      "\u001b[0;34m    op_name : str\u001b[0m\n",
      "\u001b[0;34m        The name of the op.\u001b[0m\n",
      "\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m    fstrategy : function (attrs: Attrs, inputs: List[Tensor], out_type: Type,\u001b[0m\n",
      "\u001b[0;34m                          target:Target) -> OpStrategy\u001b[0m\n",
      "\u001b[0;34m        The strategy function. Need to be native GenericFunc.\u001b[0m\n",
      "\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m    level : int\u001b[0m\n",
      "\u001b[0;34m        The priority level\u001b[0m\n",
      "\u001b[0;34m    \"\"\"\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfstrategy\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mGenericFunc\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m        \u001b[0;32massert\u001b[0m \u001b[0mhasattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfstrategy\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"generic_func_node\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m        \u001b[0mfstrategy\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfstrategy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgeneric_func_node\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0;32mreturn\u001b[0m \u001b[0mtvm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mir\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mregister_op_attr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mop_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"FTVMStrategy\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfstrategy\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlevel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mFile:\u001b[0m      /media/pc/data/lxw/ai/tvm/python/tvm/relay/op/op.py\n",
      "\u001b[0;31mType:\u001b[0m      function"
     ]
    }
   ],
   "source": [
    "from tvm.relay.op.op import register_strategy\n",
    "\n",
    "register_strategy??"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```cpp\n",
    "class GenericFuncNode;\n",
    "\n",
    "/*!\n",
    " * \\brief Generic function that can be specialized on a per-target basis.\n",
    " */\n",
    "class GenericFunc : public ObjectRef {\n",
    " public:\n",
    "  GenericFunc() {}\n",
    "  explicit GenericFunc(ObjectPtr<Object> n) : ObjectRef(n) {}\n",
    "\n",
    "  /*!\n",
    "   * \\brief Set the default function implementaiton.\n",
    "   * \\param value The default function\n",
    "   * \\param allow_override If true, this call may override a previously registered function. If\n",
    "   * false, an error will be logged if the call would override a previously registered function.\n",
    "   * \\return reference to self.\n",
    "   */\n",
    "  TVM_DLL GenericFunc& set_default(const runtime::PackedFunc value, bool allow_override = false);\n",
    "  /*!\n",
    "   * \\brief Register a specialized function\n",
    "   * \\param tags The tags for this specialization\n",
    "   * \\param value The specialized function\n",
    "   * \\param allow_override If true, this call may override previously registered tags. If false,\n",
    "   * an error will be logged if the call would override previously registered tags.\n",
    "   * \\return reference to self.\n",
    "   */\n",
    "  TVM_DLL GenericFunc& register_func(const std::vector<std::string>& tags,\n",
    "                                     const runtime::PackedFunc value, bool allow_override = false);\n",
    "  /*!\n",
    "   * \\brief Call generic function by directly passing in unpacked format.\n",
    "   * \\param args Arguments to be passed.\n",
    "   * \\tparam Args arguments to be passed.\n",
    "   *\n",
    "   * \\code\n",
    "   *   // Example code on how to call generic function\n",
    "   *   void CallGeneric(GenericFunc f) {\n",
    "   *     // call like normal functions by pass in arguments\n",
    "   *     // return value is automatically converted back\n",
    "   *     int rvalue = f(1, 2.0);\n",
    "   *   }\n",
    "   * \\endcode\n",
    "   */\n",
    "  template <typename... Args>\n",
    "  inline runtime::TVMRetValue operator()(Args&&... args) const;\n",
    "  /*!\n",
    "   * \\brief Invoke the relevant function for the current target context, set by set_target_context.\n",
    "   * Arguments are passed in packed format.\n",
    "   * \\param args The arguments to pass to the function.\n",
    "   * \\param ret The return value\n",
    "   */\n",
    "  TVM_DLL void CallPacked(runtime::TVMArgs args, runtime::TVMRetValue* ret) const;\n",
    "  /*!\n",
    "   * \\brief Get the packed function specified for the current target context.\n",
    "   */\n",
    "  TVM_DLL PackedFunc GetPacked() const;\n",
    "  /*!\n",
    "   * \\brief Find or register the GenericFunc instance corresponding to the give name\n",
    "   * \\param name The name of the registered GenericFunc\n",
    "   * \\return The GenericFunc instance\n",
    "   */\n",
    "  TVM_DLL static GenericFunc Get(const std::string& name);\n",
    "\n",
    "  /*!\n",
    "   * \\brief Add a GenericFunc instance to the registry\n",
    "   * \\param func The GenericFunc instance\n",
    "   * \\param name The name of the registered GenericFunc\n",
    "   */\n",
    "  TVM_DLL static void RegisterGenericFunc(GenericFunc func, const std::string& name);\n",
    "\n",
    "  /*!\n",
    "   * \\brief access the internal node container\n",
    "   * \\return the pointer to the internal node container\n",
    "   */\n",
    "  inline GenericFuncNode* operator->();\n",
    "\n",
    "  // declare container type\n",
    "  using ContainerType = GenericFuncNode;\n",
    "\n",
    "  // Internal class.\n",
    "  struct Manager;\n",
    "\n",
    " private:\n",
    "  friend struct Manager;\n",
    "};\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这段代码定义了一个名为`GenericFunc`的类，它继承自`ObjectRef`。这个类表示一个通用函数，可以针对每个目标平台进行特化。\n",
    "\n",
    "该类中声明了以下成员函数、类型别名和内部结构体：\n",
    "\n",
    "- `set_default`函数用于设置默认函数实现，并返回对自身的引用。\n",
    "- `register_func`函数用于注册一个特化函数，并返回对自身的引用。\n",
    "- `operator()`函数用于以未打包的形式直接传入参数来调用通用函数。\n",
    "- `CallPacked`函数用于调用与当前目标上下文对应的函数，参数以打包格式传递。\n",
    "- `GetPacked`函数用于获取为当前目标上下文指定的打包函数。\n",
    "- `Get`函数用于查找或注册给定名称的`GenericFunc`实例。\n",
    "- `RegisterGenericFunc`函数用于将`GenericFunc`实例添加到注册表中。\n",
    "- `operator->`运算符用于访问内部节点容器。\n",
    "- `ContainerType`类型别名表示内部节点容器的类型。\n",
    "- `Manager`结构体表示内部管理类。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py312x",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
